# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import imgaug
import imgaug.augmenters as iaa


# class AugmenterBuilder(object):
#     def __init__(self):
#         pass

#     def build(self, args, root=True):
#         if args is None or len(args) == 0:
#             return None
#         elif isinstance(args, list):
#             if root:
#                 sequence = [self.build(value, root=False) for value in args]
#                 return iaa.Sequential(sequence)
#             else:
#                 return getattr(iaa, args[0])(
#                     *[self.to_tuple_if_list(a) for a in args[1:]])
#         elif isinstance(args, dict):
#             cls = getattr(iaa, args['type'])
#             return cls(**{
#                 k: self.to_tuple_if_list(v)
#                 for k, v in args['args'].items()
#             })
#         else:
#             raise RuntimeError('unknown augmenter arg: ' + str(args))

#     def to_tuple_if_list(self, obj):
#         if isinstance(obj, list):
#             return tuple(obj)
#         return obj


class AugmenterBuilder(object):
    """Build imgaug augmenter pipelines from plain config dicts.

    Each spec dict has the form
    ``{'type': <name>, 'args': {...}, 'sometimes': <p>}``:

    * ``type`` names a class in ``imgaug.augmenters``.
    * ``args`` holds the keyword arguments for that class.
    * ``sometimes`` is optional (default 1). When it is below 1, the
      augmenter is wrapped in ``iaa.Sometimes(p, ...)``.
    """

    def __init__(self):
        pass

    @staticmethod
    def to_tuple_if_list(obj):
        """Return ``tuple(obj)`` if *obj* is a list, else *obj* unchanged.

        Config formats such as YAML can only express sequences as lists.
        imgaug, however, treats a list parameter as a set of discrete
        choices and a tuple as a continuous ``(low, high)`` range.
        Converting lets ``rotate: [-10, 10]`` mean "any angle in
        [-10, 10]" rather than "either -10 or 10".
        """
        if isinstance(obj, list):
            return tuple(obj)
        return obj

    def make_aug(self, param):
        """Instantiate a single augmenter from its spec dict.

        Args:
            param (dict): spec with ``'type'``, optional ``'args'`` and
                optional ``'sometimes'`` keys (see class docstring).

        Returns:
            The imgaug augmenter, wrapped in ``iaa.Sometimes`` when
            ``param['sometimes'] < 1``.
        """
        cls = getattr(iaa, param['type'])
        # `or {}` also covers an explicitly empty/None 'args' entry.
        kwargs = {
            k: self.to_tuple_if_list(v)
            for k, v in (param.get('args') or {}).items()
        }
        fun = cls(**kwargs)
        p = param.get('sometimes', 1)
        if p < 1:
            fun = iaa.Sometimes(p, fun)
        return fun

    def build(self, augmenter_args, random_order=False, someof=0):
        """Build a pipeline from a list of augmenter spec dicts.

        Args:
            augmenter_args (list[dict] | None): augmenter specs. ``None``
                is treated as an empty list.
            random_order (bool): forwarded to ``iaa.Sequential`` /
                ``iaa.SomeOf``.
            someof: when ``0``, every augmenter is applied via
                ``iaa.Sequential``. Otherwise it is passed as the ``n``
                argument of ``iaa.SomeOf``.

        Returns:
            An ``iaa.Sequential`` or ``iaa.SomeOf`` augmenter.
        """
        fun_list = [self.make_aug(args) for args in (augmenter_args or [])]
        if someof == 0:
            return iaa.Sequential(fun_list, random_order=random_order)
        return iaa.SomeOf(someof, fun_list, random_order=random_order)


class IaaAugment():
    """Randomly apply an imgaug pipeline to a sample's image and polygons.

    Args:
        augmenter_args (list[dict] | None): augmenter specs passed to
            ``AugmenterBuilder.build``. Each dict names an
            ``imgaug.augmenters`` class under ``'type'`` and its keyword
            arguments under ``'args'``. When ``None``, a default pipeline
            is used: horizontal flip with p=0.5, Affine rotation with
            ``rotate`` [-10, 10], and Resize with ``size`` [0.5, 3.0].
        prob (float): probability that the pipeline is applied to a
            given sample.
        **kwargs: ``random_order`` (default False) and ``someof``
            (default 0) are forwarded to ``AugmenterBuilder.build``.
    """

    def __init__(self, augmenter_args=None, prob=0.5, **kwargs):
        if augmenter_args is None:
            augmenter_args = [
                {'type': 'Fliplr', 'args': {'p': 0.5}},
                {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
                {'type': 'Resize', 'args': {'size': [0.5, 3.0]}},
            ]
        random_order = kwargs.get('random_order', False)
        someof = kwargs.get('someof', 0)
        self.augmenter = AugmenterBuilder().build(
            augmenter_args, random_order, someof)
        self.prob = prob

    def __call__(self, data):
        """Augment ``data['image']`` and, if present, ``data['polys']``.

        Args:
            data (dict): sample with an ``'image'`` entry exposing
                ``.shape``, and optionally a ``'polys'`` entry holding an
                iterable of polygons given as (x, y) point sequences.

        Returns:
            The same ``data`` dict. It is augmented in place when the
            random draw selects it; otherwise it is returned unchanged.
        """
        if np.random.uniform() > self.prob:
            # Skip augmentation but still pass the sample through.
            # Returning None here would lose it for any caller that
            # chains the result.
            return data

        image = data['image']
        # Polygon keypoints are expressed relative to the image as it was
        # *before* augmentation, so capture the shape first.
        shape = image.shape

        if self.augmenter:
            # Freeze the random parameters so the image and its keypoints
            # receive the identical transform.
            aug = self.augmenter.to_deterministic()
            data['image'] = aug.augment_image(image)
            data = self.may_augment_annotation(aug, data, shape)
        return data

    def may_augment_annotation(self, aug, data, shape):
        """Apply ``aug`` to every polygon in ``data['polys']``, if any.

        Args:
            aug: deterministic imgaug augmenter, or ``None`` to skip.
            data (dict): sample dict; ``data['polys']`` is replaced by a
                ``np.array`` of the transformed polygons.
            shape (tuple): shape of the image before augmentation.

        Returns:
            The (possibly updated) ``data`` dict.
        """
        if aug is None:
            return data
        if data.get('polys') is None:
            return data

        data['polys'] = np.array([
            self.may_augment_poly(aug, shape, poly)
            for poly in data['polys']
        ])
        return data

    def may_augment_poly(self, aug, img_shape, poly):
        """Transform one polygon's points with ``aug``.

        Args:
            aug: deterministic imgaug augmenter.
            img_shape (tuple): shape of the image the points belong to.
            poly: sequence of points indexable as ``p[0]`` (x) and
                ``p[1]`` (y).

        Returns:
            list[tuple]: transformed ``(x, y)`` points.
        """
        keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
        keypoints = aug.augment_keypoints(
            [imgaug.KeypointsOnImage(
                keypoints, shape=img_shape)])[0].keypoints
        poly = [(p.x, p.y) for p in keypoints]
        return poly
